{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/nicolas-chaulet/torch-points3d/blob/master/notebooks/PartSegmentationKPConv.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sMPGNDv1vDn0"
   },
   "outputs": [],
   "source": [
    "# Setup packages can take some time (30 minutes or so)\n",
    "!pip install pyvista pytorch-lightning\n",
    "!pip install --upgrade jsonschema\n",
    "!pip install torch-points3d\n",
    "!apt-get install -qq xvfb libgl1-mesa-glx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "7_9hcYlFvOv6",
    "outputId": "9621ae8a-40c5-4790-cc11-6975d675a8da"
   },
   "outputs": [],
   "source": [
    "# Needed for remote rendering \n",
    "import os\n",
    "os.environ[\"DISPLAY\"] = \":1.0\"\n",
    "os.environ[\"PYVISTA_OFF_SCREEN\"]=\"true\"\n",
    "os.environ[\"PYVISTA_PLOT_THEME\"]=\"true\"\n",
    "os.environ[\"PYVISTA_USE_PANEL\"]=\"true\"\n",
    "os.environ[\"PYVISTA_AUTO_CLOSE\"]=\"false\"\n",
    "os.system(\"Xvfb :1 -screen 0 1024x768x24 > /dev/null 2>&1 &\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HGza3XTYvQNC"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from omegaconf import OmegaConf\n",
    "import pyvista as pv\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Q44YjGjrvg2j"
   },
   "outputs": [],
   "source": [
    "DIR = \"\" # Replace with your root directory, the data will go in DIR/data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "izN8KU_tv5JR"
   },
   "source": [
    "<p align=\"center\">\n",
    "  <img width=\"40%\" src=\"https://raw.githubusercontent.com/nicolas-chaulet/torch-points3d/master/docs/logo.png\" />\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Mgox3ksbv-s1"
   },
   "source": [
    "# Segmenting objects in part with KPConv\n",
    "In this notebook we will solve the task of segmenting an object into its sub parts by using a [KPConv](https://arxiv.org/abs/1904.08889) deep neural network.\n",
    "We will work on [ShapeNet](https://www.shapenet.org/) dataset which contains 48,600 3D models over 55 common categories with part annotations. We will show you how you can use Torch Points3D to setup a KPConv backbone with a multi head classifier and train it on ShapeNet. We will in particular cover the CPU pre processing of the data that allows a complex KPConv to perform well."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wyFtUQhGybFc"
   },
   "source": [
    "## The dataset\n",
    "We use Torch Points3D version of ShapeNet that provides automatic download (be patient, it takes some time...) of the data, a tested metric tracker as well as methods for pre computing the spatial operations such as neighbour search and grid sampling on CPU.\n",
    "\n",
    "Let's start with the data config (if you want more details about that part of Torch Points3D please refer to this notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "JnLlPJ_dmzOV"
   },
   "outputs": [],
   "source": [
    "#@title Configure the dataset {run: \"auto\"}\n",
    "CATEGORY = \"All\" #@param [\"Airplane\", \"Bag\", \"All\", \"Motorbike\"] {allow-input: true}\n",
    "USE_NORMALS = True #@param {type:\"boolean\"}|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PBRS6vxSvjRo"
   },
   "outputs": [],
   "source": [
    "shapenet_yaml = \"\"\"\n",
    "class: shapenet.ShapeNetDataset\n",
    "task: segmentation\n",
    "dataroot: %s\n",
    "normal: %r                                    # Use normal vectors as features\n",
    "first_subsampling: 0.02                       # Grid size of the input data\n",
    "pre_transforms:                               # Offline transforms, done only once\n",
    "    - transform: NormalizeScale           \n",
    "    - transform: GridSampling3D\n",
    "      params:\n",
    "        size: ${first_subsampling}\n",
    "train_transforms:                             # Data augmentation pipeline\n",
    "    - transform: RandomNoise\n",
    "      params:\n",
    "        sigma: 0.01\n",
    "        clip: 0.05\n",
    "    - transform: RandomScaleAnisotropic\n",
    "      params:\n",
    "        scales: [0.9,1.1]\n",
    "    - transform: AddOnes\n",
    "    - transform: AddFeatsByKeys\n",
    "      params:\n",
    "        list_add_to_x: [True]\n",
    "        feat_names: [\"ones\"]\n",
    "        delete_feats: [True]\n",
    "test_transforms:\n",
    "    - transform: AddOnes\n",
    "    - transform: AddFeatsByKeys\n",
    "      params:\n",
    "        list_add_to_x: [True]\n",
    "        feat_names: [\"ones\"]\n",
    "        delete_feats: [True]\n",
    "\"\"\" % (os.path.join(DIR,\"data\"), USE_NORMALS) \n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "params = OmegaConf.create(shapenet_yaml)\n",
    "if CATEGORY != \"All\":\n",
    "    params.category = CATEGORY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 719,
     "referenced_widgets": [
      "20b156698ad1434cb5c2a4784da17469",
      "78d669df9595452ba48e37fac70b4982",
      "c4441c424d4c49958377ea28acdc85d4",
      "d85fb13a296e47bdaa84f34c70b73d81",
      "3086aa480b0a4006a92d536f1080bcc0",
      "4845349a7244472cbdfc2f13d49309b5",
      "abf37646a379440fbeae96e7315b01ae",
      "1f823b0dbcaa4ca8b750095c57848a38",
      "eb4b9d4a0f9d419caf308c9b684aa4f8",
      "8946a00597a84d16a750946fc34c5ed9",
      "5ae191c92d7c4ce5b56d04d0863a64b0",
      "bbafde9aa6e140348450ecc18efd1ba0",
      "cba3f22a13b745909d373fe2ff63d340",
      "45ad74d501f9493fb81347da950bf966",
      "a84f426e9a4f4847acb2d5d8a0bfff01",
      "856b111759f34b279e7dcc9cc63b136c",
      "a38dc43c225248019de3debec8132a53",
      "8bccc35bc8ff4461ab4e25fbd50e241e",
      "a67c79d4c5294c369a59976bbb2b8c73",
      "5920e0d0221744b08c89575f5ea640a4",
      "8875f62a0ffe4e3c9c39b3c9cf65264e",
      "cbb0fd9d7ca24e39be03290d69b4e106",
      "514041c0737f459f9a4a5f8c006cdd0e",
      "b27eabe1c9634e529bf5384b7ee675d1"
     ]
    },
    "colab_type": "code",
    "id": "eQFRmfUYzXX1",
    "outputId": "14c20c6a-da1d-484c-e393-255206b708a8"
   },
   "outputs": [],
   "source": [
    "# The first time you run this cell, it will download the dataset \n",
    "from torch_points3d.datasets.segmentation import ShapeNetDataset\n",
    "dataset = ShapeNetDataset(params)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "colab_type": "code",
    "id": "wAN1JCiU1pFW",
    "outputId": "18443e20-2447-4ff5-b2b9-fe6f5d66205b"
   },
   "outputs": [],
   "source": [
    "#@title Plot samples with part annotations { run: \"auto\" }\n",
    "objectid_1 = 9 #@param {type:\"slider\", min:0, max:100, step:1}\n",
    "objectid_2 = 82 #@param {type:\"slider\", min:0, max:100, step:1}\n",
    "objectid_3 = 95 #@param {type:\"slider\", min:0, max:100, step:1}\n",
    "\n",
    "samples = [objectid_1,objectid_2,objectid_3]\n",
    "p = pv.Plotter(notebook=True,shape=(1, len(samples)),window_size=[1024,412])\n",
    "for i in range(len(samples)):\n",
    "    p.subplot(0, i)\n",
    "    sample = dataset.train_dataset[samples[i]]\n",
    "    point_cloud = pv.PolyData(sample.pos.numpy())\n",
    "    point_cloud['y'] = sample.y.numpy()\n",
    "    p.add_points(point_cloud,  show_scalar_bar=False, point_size=3)\n",
    "    p.camera_position = [-1,5, -10]\n",
    "p.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sLSGJVf30G1C"
   },
   "source": [
    "## Model for part segmentation\n",
    "Let's start by creating a multihead segmentation module with one segmentation head per category. We provide that as part of Torch Points3D but let's reproduce it here for sake of completeness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "V4k79sQhzryp"
   },
   "outputs": [],
   "source": [
    "from torch_points3d.core.common_modules import MLP, UnaryConv\n",
    "\n",
    "class MultiHeadClassifier(torch.nn.Module):\n",
    "    \"\"\" Allows segregated segmentation in case the category of an object is known. \n",
    "    This is the case in ShapeNet for example.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    in_features -\n",
    "        size of the input channel\n",
    "    cat_to_seg\n",
    "        category to segment maps for example:\n",
    "        {\n",
    "            'Airplane': [0,1,2],\n",
    "            'Table': [3,4]\n",
    "        }\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, cat_to_seg, dropout_proba=0.5, bn_momentum=0.1):\n",
    "        super().__init__()\n",
    "        self._cat_to_seg = {}\n",
    "        self._num_categories = len(cat_to_seg)\n",
    "        self._max_seg_count = 0\n",
    "        self._max_seg = 0\n",
    "        self._shifts = torch.zeros((self._num_categories,), dtype=torch.long)\n",
    "        for i, seg in enumerate(cat_to_seg.values()):\n",
    "            self._max_seg_count = max(self._max_seg_count, len(seg))\n",
    "            self._max_seg = max(self._max_seg, max(seg))\n",
    "            self._shifts[i] = min(seg)\n",
    "            self._cat_to_seg[i] = seg\n",
    "\n",
    "        self.channel_rasing = MLP(\n",
    "            [in_features, self._num_categories * in_features], bn_momentum=bn_momentum, bias=False\n",
    "        )\n",
    "        if dropout_proba:\n",
    "            self.channel_rasing.add_module(\"Dropout\", torch.nn.Dropout(p=dropout_proba))\n",
    "\n",
    "        self.classifier = UnaryConv((self._num_categories, in_features, self._max_seg_count))\n",
    "        self._bias = torch.nn.Parameter(torch.zeros(self._max_seg_count,))\n",
    "\n",
    "    def forward(self, features, category_labels, **kwargs):\n",
    "        assert features.dim() == 2\n",
    "        self._shifts = self._shifts.to(features.device)\n",
    "        in_dim = features.shape[-1]\n",
    "        features = self.channel_rasing(features)\n",
    "        features = features.reshape((-1, self._num_categories, in_dim))\n",
    "        features = features.transpose(0, 1)  # [num_categories, num_points, in_dim]\n",
    "        features = self.classifier(features) + self._bias  # [num_categories, num_points, max_seg]\n",
    "        ind = category_labels.unsqueeze(-1).repeat(1, 1, features.shape[-1]).long()\n",
    "\n",
    "        logits = features.gather(0, ind).squeeze(0)\n",
    "        softmax = torch.nn.functional.log_softmax(logits, dim=-1)\n",
    "\n",
    "        output = torch.zeros(logits.shape[0], self._max_seg + 1).to(features.device)\n",
    "        cats_in_batch = torch.unique(category_labels)\n",
    "        for cat in cats_in_batch:\n",
    "            cat_mask = category_labels == cat\n",
    "            seg_indices = self._cat_to_seg[cat.item()]\n",
    "            probs = softmax[cat_mask, : len(seg_indices)]\n",
    "            output[cat_mask, seg_indices[0] : seg_indices[-1] + 1] = probs\n",
    "        \n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vk2LyH-IT2KG"
   },
   "source": [
    "The model we implement here follows the main architecture proposed in the [original paper](https://arxiv.org/abs/1904.08889):\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"70%\" src=\"https://drive.google.com/uc?export=view&id=1CJppQ88T69whjYsJc016L3_E_rtcJ8n1\" />\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZnyHC6Gx1nIw"
   },
   "outputs": [],
   "source": [
    "from torch_points3d.applications.kpconv import KPConv\n",
    "\n",
    "\n",
    "class PartSegKPConv(torch.nn.Module):\n",
    "    def __init__(self, cat_to_seg):\n",
    "        super().__init__()\n",
    "        self.unet = KPConv(\n",
    "            architecture=\"unet\", \n",
    "            input_nc=USE_NORMALS * 3, \n",
    "            num_layers=4, \n",
    "            in_grid_size=0.02\n",
    "            )\n",
    "        self.classifier = MultiHeadClassifier(self.unet.output_nc, cat_to_seg)\n",
    "    \n",
    "    @property\n",
    "    def conv_type(self):\n",
    "        \"\"\" This is needed by the dataset to infer which batch collate should be used\"\"\"\n",
    "        return self.unet.conv_type\n",
    "    \n",
    "    def get_batch(self):\n",
    "        return self.batch\n",
    "    \n",
    "    def get_output(self):\n",
    "        \"\"\" This is needed by the tracker to get access to the ouputs of the network\"\"\"\n",
    "        return self.output\n",
    "    \n",
    "    def get_labels(self):\n",
    "        \"\"\" Needed by the tracker in order to access ground truth labels\"\"\"\n",
    "        return self.labels\n",
    "    \n",
    "    def get_current_losses(self):\n",
    "        \"\"\" Entry point for the tracker to grab the loss \"\"\"\n",
    "        return {\"loss_seg\": float(self.loss_seg)}\n",
    "\n",
    "    def forward(self, data):\n",
    "        self.labels = data.y\n",
    "        self.batch = data.batch\n",
    "        \n",
    "        # Forward through unet and classifier\n",
    "        data_features = self.unet(data)\n",
    "        self.output = self.classifier(data_features.x, data.category)\n",
    "\n",
    "         # Set loss for the backward pass\n",
    "        self.loss_seg = torch.nn.functional.nll_loss(self.output, self.labels)\n",
    "        return self.output\n",
    "\n",
    "    def get_spatial_ops(self):\n",
    "        return self.unet.get_spatial_ops()\n",
    "        \n",
    "    def backward(self):\n",
    "         self.loss_seg.backward() \n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QQ0GUiv22KKb"
   },
   "outputs": [],
   "source": [
    "model = PartSegKPConv(dataset.class_to_segments)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "k9Bx8GjE3Kt7"
   },
   "source": [
    "## The data loaders and CPU pre computing features\n",
    "KPConv is quite demanding on spatial operations such as grid sampling and radius search. On the network loaded here we have 10 KPConv layers on the encoder which means 10 radius search operations with varying number of neighbours. We observed a significant performance gain by moving those operations to the CPU where they can easily be optimised with suitable data structures such as kd-tree. We use [nonaflann](https://github.com/jlblancoc/nanoflann) in the back-end, a 3D optimised kd-tree implementation. Note that this is beneficiary only if you have access to multiple CPU threads.\n",
    "\n",
    "You can decide to precompute those spatial operations by setting the `precompute_multi_scale` parameter to `True` when creating the data loaders. The dataset will mine the model to figure out which spatial operations are required and in which order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i1k4CNG12RXN"
   },
   "outputs": [],
   "source": [
    "NUM_WORKERS = 4\n",
    "BATCH_SIZE = 16\n",
    "dataset.create_dataloaders(\n",
    "    model,\n",
    "    batch_size=BATCH_SIZE, \n",
    "    num_workers=NUM_WORKERS, \n",
    "    shuffle=True, \n",
    "    precompute_multi_scale=True \n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 182
    },
    "colab_type": "code",
    "id": "4wdS3y_aVIbm",
    "outputId": "5e564d1d-e966-41f3-a95f-ea830ee239f7"
   },
   "outputs": [],
   "source": [
    "sample = next(iter(dataset.train_dataloader))\n",
    "sample.keys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "A5pl2g9ap1Fq"
   },
   "source": [
    "Our `sample` contains the pre computed spatial information in the `multiscale` (encoder side) and `upsample` (decoder) attrivutes. The decoder pre computing is quite simple and just involves some basic caching for the nearest neighbour interpolation operation. Let's take a look at the encoder side of things first. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 201
    },
    "colab_type": "code",
    "id": "z0o8yQF4FWvV",
    "outputId": "3a17a960-ea06-4037-e224-5e0853a48b9e"
   },
   "outputs": [],
   "source": [
    "sample.multiscale"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "B-2bjfQyqUyH"
   },
   "source": [
    "`sample.multiscale` contains 10 different versions of the input batch, each one of these versions contains the location of the points in `pos` as well as the neighbours of these points in the previous point cloud. We will first look at the points coming out of each downsampling layer (strided convolution), we have 5 of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 270
    },
    "colab_type": "code",
    "id": "qmR0S-F5qTvV",
    "outputId": "eae3af58-5eed-43c7-d083-54dfde409abb"
   },
   "outputs": [],
   "source": [
    "#@title Successive downsampling {run:\"auto\"}\n",
    "sample_in_batch = 0 #@param {type:\"slider\", min:0, max:5, step:1}\n",
    "ms_data = sample.multiscale \n",
    "num_downsize = int(len(ms_data) / 2)\n",
    "p = pv.Plotter(notebook=True,shape=(1, num_downsize),window_size=[1024,256])\n",
    "for i in range(0,num_downsize):\n",
    "    p.subplot(0, i)\n",
    "    pos = ms_data[2*i].pos[ms_data[2*i].batch == sample_in_batch].numpy()\n",
    "    point_cloud = pv.PolyData(pos)\n",
    "    point_cloud['y'] = pos[:,1]\n",
    "    p.add_points(point_cloud,  show_scalar_bar=False, point_size=3)\n",
    "    p.add_text(\"Layer {}\".format(i+1),font_size=10)\n",
    "    p.camera_position = [-1,5, -10]\n",
    "p.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ONPofTEptKyL"
   },
   "source": [
    "Let's now take one point in a layer (query point) and show its neighbours in the previous layer (support point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "colab_type": "code",
    "id": "-uVmbA_OqDrg",
    "outputId": "bcc6d7af-a4ca-432c-c1bb-0c85936bc32b"
   },
   "outputs": [],
   "source": [
    "#@title Explore Neighborhood {run: \"auto\"}\n",
    "selected_layer = 7 #@param {type:\"slider\", min:1, max:9, step:1}\n",
    "sample_in_batch = 0 #@param {type:\"slider\", min:0, max:5, step:1}\n",
    "point1_id = 46 #@param {type:\"slider\", min:0, max:600, step:1}\n",
    "point2_id =  0#@param {type:\"slider\", min:0, max:600, step:1}\n",
    "\n",
    "p = pv.Plotter(notebook=True,shape=(1, 2),window_size=[1024,412])\n",
    "\n",
    "# Selected layer\n",
    "p.subplot(0, 1)\n",
    "ms_data = sample.multiscale[selected_layer]\n",
    "pos = ms_data.pos[ms_data.batch == sample_in_batch].numpy()\n",
    "nei = ms_data.idx_neighboors[ms_data.batch == sample_in_batch]\n",
    "point_cloud = pv.PolyData(pos)\n",
    "p.add_points(point_cloud,  show_scalar_bar=False, point_size=3,opacity=0.3)\n",
    "p.add_points(pos[point1_id,:],  show_scalar_bar=False, point_size=7.0,color='red')\n",
    "p.add_points(pos[point2_id,:],  show_scalar_bar=False, point_size=7.0,color='green')\n",
    "p.camera_position = [-1,5, -10]\n",
    "\n",
    "# Previous layer\n",
    "p.subplot(0, 0)\n",
    "ms_data = sample.multiscale[selected_layer-1]\n",
    "pos = ms_data.pos[ms_data.batch == sample_in_batch].numpy()\n",
    "point_cloud = pv.PolyData(pos)\n",
    "p.add_points(point_cloud,  show_scalar_bar=False,point_size=3, opacity=0.3)\n",
    "nei_pos = ms_data.pos[nei[point1_id]].numpy()\n",
    "nei_pos = nei_pos[nei[point1_id] >= 0]\n",
    "p.add_points(nei_pos,  show_scalar_bar=False, point_size=3.0,color='red')\n",
    "nei_pos = ms_data.pos[nei[point2_id]].numpy()\n",
    "nei_pos = nei_pos[nei[point2_id] >= 0]\n",
    "p.add_points(nei_pos,  show_scalar_bar=False, point_size=3.0,color='green')\n",
    "p.camera_position = [-1,5, -10]\n",
    "\n",
    "p.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OsOeG_VXW5lC"
   },
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Hg3O2swzW7if"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import time\n",
    "\n",
    "class Trainer:\n",
    "    def __init__(self,model, dataset, num_epoch = 50, device=torch.device('cuda')):\n",
    "        self.num_epoch = num_epoch\n",
    "        self._model = model\n",
    "        self._dataset=dataset\n",
    "        self.device = device\n",
    "\n",
    "    def fit(self):\n",
    "        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)\n",
    "        self.tracker = self._dataset.get_tracker(False, True)\n",
    "\n",
    "        for i in range(self.num_epoch):\n",
    "            print(\"=========== EPOCH %i ===========\" % i)\n",
    "            time.sleep(0.5)\n",
    "            self.train_epoch()\n",
    "            self.tracker.publish(i)\n",
    "            self.test_epoch()\n",
    "            self.tracker.publish(i)\n",
    "\n",
    "    def train_epoch(self):\n",
    "        self._model.to(self.device)\n",
    "        self._model.train()\n",
    "        self.tracker.reset(\"train\")\n",
    "        train_loader = self._dataset.train_dataloader\n",
    "        iter_data_time = time.time()\n",
    "        with tqdm(train_loader) as tq_train_loader:\n",
    "            for i, data in enumerate(tq_train_loader):\n",
    "                t_data = time.time() - iter_data_time\n",
    "                iter_start_time = time.time()\n",
    "                self.optimizer.zero_grad()\n",
    "                data.to(self.device)\n",
    "                self._model.forward(data)\n",
    "                self._model.backward()\n",
    "                self.optimizer.step()\n",
    "                if i % 10 == 0:\n",
    "                    self.tracker.track(self._model)\n",
    "\n",
    "                tq_train_loader.set_postfix(\n",
    "                    **self.tracker.get_metrics(),\n",
    "                    data_loading=float(t_data),\n",
    "                    iteration=float(time.time() - iter_start_time),\n",
    "                )\n",
    "                iter_data_time = time.time()\n",
    "\n",
    "    def test_epoch(self):\n",
    "        self._model.to(self.device)\n",
    "        self._model.eval()\n",
    "        self.tracker.reset(\"test\")\n",
    "        test_loader = self._dataset.test_dataloaders[0]\n",
    "        iter_data_time = time.time()\n",
    "        with tqdm(test_loader) as tq_test_loader:\n",
    "            for i, data in enumerate(tq_test_loader):\n",
    "                t_data = time.time() - iter_data_time\n",
    "                iter_start_time = time.time()\n",
    "                data.to(self.device)\n",
    "                self._model.forward(data)           \n",
    "                self.tracker.track(self._model)\n",
    "\n",
    "                tq_test_loader.set_postfix(\n",
    "                    **self.tracker.get_metrics(),\n",
    "                    data_loading=float(t_data),\n",
    "                    iteration=float(time.time() - iter_start_time),\n",
    "                )\n",
    "                iter_data_time = time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xFaIG1SBchHg"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(model, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 95,
     "referenced_widgets": [
      "8f58dfcaa0dd4dc0b4671c8350199388",
      "efaf0403696444919d2a454b1a77d28b",
      "d854f496052a4bf29a3459c487fe5304",
      "8d4e84720a394f44b0994e3788285878",
      "4504ad204d3241a6ba2f71cd4021a028",
      "e54acd47d4f8465da63ff2a5d425c42c",
      "3d23c649ccf7408c85905baa03a144aa",
      "93fbdcc2d2094c37a36c34b1d9c769df"
     ]
    },
    "colab_type": "code",
    "id": "ydhhP0CdcqO1",
    "outputId": "0888b8ae-e57d-4aac-e53b-e30067bc1dac"
   },
   "outputs": [],
   "source": [
    "trainer.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y69S3OsL2UKL"
   },
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir lightning_logs/version_4/ # Change for your log location"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tiXF_nwx7VCl"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "PartSegmentationKPConv.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
